import json
import os
import csv
from typing import List, Dict
from openai import OpenAI
# from dotenv import load_dotenv

# Shared OpenAI client used by the analysis functions below. The API key is
# read from the OPENAI_API_KEY environment variable (None when it is unset);
# no .env file is loaded here, so the variable must already be exported.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Prompt templates, one per task type. Both are filled in with str.format(),
# so any literal "{" or "}" added to the text must be doubled ("{{" / "}}").
#
# SINGLE_TASK_PROMPT is used by analyze_single_task; its {output} placeholder
# receives the task's 'output' value (the execution trajectory).
SINGLE_TASK_PROMPT = """
Analyze this web agent automation task execution:
Task execution trajectory: {output}
Please provide:
1. A summary of the task execution process
2. An analysis of potential issues encountered
3. Concrete improvement suggestions for future tasks
"""

# BATCH_TASK_PROMPT is used by analyze_batch_task, once per task; {output} and
# {groundtruth} receive that task's 'output' and 'groundtruth' values.
BATCH_TASK_PROMPT = """
Analyze this specific web agent automation task:
Task execution: {output}
Expected result: {groundtruth}
Please provide:
1. Task completion assessment (Success/Partial/Failed)
2. Gap analysis between execution and expected result
3. Specific improvement suggestions
"""

def analyze_single_task(data: List[Dict], model: str = "gpt-4o") -> str:
    """Ask the model to summarize and critique a single task's execution.

    Args:
        data: Parsed single-task file. Only the first element is used; it
            must contain an ``'output'`` key holding the execution trajectory.
        model: Chat-completions model name to query.

    Returns:
        The model's analysis text, or ``""`` if the response carried no text.

    Raises:
        ValueError: If ``data`` is empty.
        KeyError: If the task has no ``'output'`` key.
    """
    if not data:
        # Fail with a clear message instead of an opaque IndexError.
        raise ValueError("analyze_single_task: input data contains no task")
    output = data[0]['output']  # a single-task file is assumed to hold one task

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are an assistant analyzing web agent automation tasks."},
            {"role": "user", "content": SINGLE_TASK_PROMPT.format(output=output)}
        ]
    )
    # message.content is Optional in the API (e.g. refusals). The caller
    # writes this value to a text file, so never hand back None.
    return response.choices[0].message.content or ""

def analyze_batch_task(data: List[Dict], model: str = "gpt-4o") -> List[Dict]:
    """Analyze every task in a batch against its expected result.

    Each task is sent to the model independently. A failure on one task is
    printed and recorded in that task's ``'analysis'`` field; the rest of the
    batch still runs. Failures include an API error or a missing
    ``'output'``/``'groundtruth'`` key.

    Args:
        data: List of task dicts. Each should provide ``'output'`` and
            ``'groundtruth'``. ``'task_id'`` is optional; it defaults to the
            task's 1-based position as a string.
        model: Chat-completions model name to query.

    Returns:
        One dict per input task, in input order, with keys ``'task_id'``,
        ``'output'``, ``'groundtruth'`` and ``'analysis'``.
    """
    results = []
    total_tasks = len(data)
    print(f"\nStarting batch analysis of {total_tasks} tasks...")

    for idx, item in enumerate(data, 1):
        task_id = item.get('task_id', str(idx))
        print(f"Processing task {task_id} ({idx}/{total_tasks})...")

        # Use .get() so the error path below can still build a record when a
        # task lacks 'output' or 'groundtruth' (indexing would re-raise there).
        record = {
            'task_id': task_id,
            'output': item.get('output'),
            'groundtruth': item.get('groundtruth'),
        }

        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are an assistant analyzing web agent automation tasks."},
                    {"role": "user", "content": BATCH_TASK_PROMPT.format(
                        output=item['output'],
                        groundtruth=item['groundtruth']
                    )}
                ]
            )
            # message.content is Optional; keep the 'analysis' column a string.
            analysis = response.choices[0].message.content or ""
        except Exception as e:
            # Best-effort per task: report, record the error, continue.
            print(f"✗ Error processing task {task_id}: {str(e)}")
            analysis = f"Error during analysis: {str(e)}"
        else:
            print(f"✓ Task {task_id} analyzed successfully")

        record['analysis'] = analysis
        results.append(record)

        if idx < total_tasks:
            print("-" * 40)  # Separator between tasks

    print(f"\nBatch analysis completed. Processed {total_tasks} tasks.")
    return results

def save_results(results, task_type: str, output_dir: str = "."):
    """Write analysis results to disk.

    Always writes ``response.json``, which is ``results`` serialized as JSON.
    Then:

    - ``task_type == 'batch_task'``: also writes ``analysis_results.csv``,
      one row per result dict.
    - any other task type: ``results`` is treated as text and written to
      ``output_content.txt``.

    Args:
        results: List of result dicts (batch) or the analysis string (single).
        task_type: ``'batch_task'`` selects CSV output; anything else text.
        output_dir: Directory to write into (default: current directory).
    """
    # Explicit UTF-8 everywhere. Model output often contains non-ASCII text
    # that the platform default encoding (e.g. cp1252) cannot encode.
    with open(os.path.join(output_dir, 'response.json'), 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    if task_type == 'batch_task':
        csv_path = os.path.join(output_dir, 'analysis_results.csv')
        # newline='' is required by the csv module to avoid blank rows on Windows.
        with open(csv_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=['task_id', 'output', 'groundtruth', 'analysis'])
            writer.writeheader()
            writer.writerows(results)
    else:
        with open(os.path.join(output_dir, 'output_content.txt'), 'w', encoding='utf-8') as f:
            # f.write(None) raises TypeError; treat "no content" as empty text.
            f.write(results if results is not None else "")

def main():
    """Run the analysis selected by ``config.json`` and save the results.

    ``config['task_type'] == 'batch_task'`` analyzes every task in
    ``batch_task_demo.json``. Any other value analyzes the single task in
    ``single_task_demo.json``.
    """
    # Read config (explicit UTF-8 so parsing does not depend on the locale).
    with open('config.json', 'r', encoding='utf-8') as config_file:
        config = json.load(config_file)

    # Look the task type up once; every branch below keys off the same value.
    task_type = config['task_type']
    is_batch = task_type == 'batch_task'
    json_file_path = 'batch_task_demo.json' if is_batch else 'single_task_demo.json'

    # Read task data
    with open(json_file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    results = analyze_batch_task(data) if is_batch else analyze_single_task(data)

    save_results(results, task_type)
    print(f"Results saved for {task_type}")

# Script entry point: run the analysis only when executed directly, not when
# this module is imported.
if __name__ == "__main__":
    main()